import numpy as np
import scipy
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
import os
from utils import df1, df2, LMC
import argparse


if __name__ == '__main__':
    # Experiment driver: for each dimension in the sweep below, call
    # LMC(config) with the gradient of the selected potential, time the call,
    # and save whatever LMC returns to
    # test_function_<potential>/experiment/d=<d>.npy.
    #
    # Usage: python <this script> --potential {log-sum-exp,cosine}

    # Fixed seed for NumPy's global RNG so repeated runs are reproducible
    # (assumes LMC draws its randomness from np.random -- TODO confirm).
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    # Required and restricted to the known potentials: otherwise
    # 'grad potential' would silently stay None and the data paths below
    # would point at a non-existent 'test_function_None' directory.
    parser.add_argument(
        "--potential",
        required=True,
        choices=['log-sum-exp', 'cosine'],
        help="potential whose gradient is passed to LMC",
    )
    args = parser.parse_args()

    # Settings consumed by utils.LMC. The None / -1 placeholders are filled
    # in per dimension inside the loop below.
    config = {
        'step size': 0.1,
        'num samples': int(1e4),
        'grad potential': None,
        'dimension': -1,
        'initial condition': 1,
        'T': 10,
        'stats function': None,
    }

    benchmark_dir = f'test_function_{args.potential}/benchmark'
    experiment_dir = f'test_function_{args.potential}/experiment'
    # Create the output directory up front so a missing folder cannot cause
    # np.save to fail after a (potentially long) run has already finished.
    os.makedirs(experiment_dir, exist_ok=True)

    for d in [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]:
        config['dimension'] = d

        if args.potential == 'log-sum-exp':
            config['grad potential'] = df1
        else:  # 'cosine' -- the only other value argparse accepts
            # Bind the current dimension as a keyword-only default so the
            # callable does not depend on the loop variable after this
            # iteration (late-binding closure pitfall).
            config['grad potential'] = lambda x, *, dim=d: df2(x, dim)

        # Reference mean: average of the stored benchmark array over axis 0.
        true_mean = np.load(f'{benchmark_dir}/d={d}.npy').mean(axis=0)
        # Statistic handed to LMC: Euclidean distance between the axis-0 mean
        # of `sample` and the benchmark mean. The reference is bound as a
        # default for the same late-binding reason as above.
        config['stats function'] = (
            lambda sample, *, ref=true_mean:
                np.linalg.norm(sample.mean(axis=0) - ref)
        )

        # perf_counter is monotonic, so the measured interval is not affected
        # by system clock adjustments.
        start = time.perf_counter()
        hist = LMC(config)
        end = time.perf_counter()
        print(f'dimension {d} finished: {end - start:.3f}s elapsed.')

        np.save(f'{experiment_dir}/d={d}.npy', hist)

        